#include <getopt.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include "DAG.hpp"
#include "bipartite.hpp"
#include "prediction.hpp"
#include "Lazy.hpp"
using namespace std;
int main(int argc, char* argv[]) {
    int n = -1, m = -1, k = -1, c = -1;
    double epsilon = -1.0;
    string algorithm;

    const char* const short_opts = "n:m:k:c:e:a:";
    const option long_opts[] = {
        {"nodes", required_argument, nullptr, 'n'},
        {"edges", required_argument, nullptr, 'm'},
        {"kth", required_argument, nullptr, 'k'},
        {"constant", required_argument, nullptr, 'c'},
        {"epsilon", required_argument, nullptr, 'e'},
        {"algorithm", required_argument, nullptr, 'a'},
        {nullptr, no_argument, nullptr, 0}
    };
    while (true) {
        const auto opt = getopt_long(argc, argv, short_opts, long_opts, nullptr);
        if (opt == -1) break;
        switch (opt) {
            case 'n': n = stoi(optarg); break;
            case 'm': m = stoi(optarg); break;
            case 'k': k = stoi(optarg); break;
            case 'c': c = stoi(optarg); break;
            case 'e': epsilon = stod(optarg); break;
            case 'a': algorithm = optarg; break;
            default:
                cerr << "Invalid argument.\n";
                return 1;
        }
    }
    if (algorithm != "rand" && algorithm != "det" && algorithm != "vote") {
        cerr << "Algorithm must be one of: rand, det, vote.\n";
        return 1;
    }
    if (algorithm == "vote") {
        if (n <= 0 || c <= 0 || epsilon < 0.0) {
            cerr << "vote requires -n, -c, -e\n";
            return 1;
        }
        k = sqrt(n);
    }
    else {
        if (n <= 0 || m <= 0 || k <= 0) {
            cerr << algorithm << " requires -n, -m, -k\n";
            return 1;
        }
    }
    vector<int> A(n);
    iota(A.begin(), A.end(), 0);
    mt19937 gen(random_device{}());
    if (algorithm == "vote") {
        double r = pow(n, 0.75);
        int s = 4 * r;
        k = sqrt(n);
        shuffle(A.begin(), A.end(), gen);
        lazySelect selector(r, k, s, epsilon);
        int result = selector.votingselect(A, c);
        bool correct = (result == n/2 - 1);
        cout<< "[vote] result: " << result 
            << ", weak: " << selector.GetWeakCnt()
            << ", strong: " << selector.GetStrongCnt()
            << ", correct: " << correct << endl;
    }
    else {
        DirectedAcyclicGraph dag(n, m);
        cout << "Preprocessing start\n";
        dag.BuildCompleteDAG();
        BipartiteGraph BP(n);
        BP.convertDAGtoBP(dag.GetDAG(), dag.GetEdges());
        BP.Match();
        int width = n - BP.getMaxMatching();
        vector<bool> active(n, true);
        vector<vector<int>> chains = dag.ChainDecomposition(BP.getMatch(), BP.getInvMatch());
        cout << "Preprocessing end\n";
        SelectWithPrediction selector;
        shuffle(A.begin(), A.end(), gen);
        int result = (algorithm == "rand")
                   ? selector.RandomizedSelect(A, k, active, chains, dag)
                   : selector.DeterministicSelect(A, k, active, chains, dag, 1);
        cout << "[" << algorithm << "] result: " << result
             << ", width: " << width
             << ", comparisons: " << selector.GetComparisonCount() << endl;
    }

    return 0;
}